{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "_cell_guid": "886cf794-b413-49f0-a78d-0003664abead",
    "_uuid": "ecb2ba94-3d48-4876-8200-65fa763b95ed",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "id": "L1X_zrhblRB2",
    "outputId": "917e227e-6fb9-437a-a46c-305e609bcbd4"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "from keras import initializers\n",
    "from keras.datasets import cifar10, mnist\n",
    "from keras.initializers import RandomNormal\n",
    "from keras.layers import (BatchNormalization, Conv2D, Conv2DTranspose, Dense,\n",
    "                          Dropout, Flatten, Input, Reshape, UpSampling2D,\n",
    "                          ZeroPadding2D)\n",
    "from keras.layers.advanced_activations import LeakyReLU\n",
    "from keras.models import Model, Sequential\n",
    "from keras.optimizers import Adam\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "_cell_guid": "56a938fc-fa20-4972-9ec9-9c6769fa6faf",
    "_uuid": "26f698d5-784c-4882-b3b4-10cfcdf493e3",
    "colab": {},
    "colab_type": "code",
    "id": "cqkoXzNElRB5"
   },
   "outputs": [],
   "source": [
    "# Consistent results\n",
    "np.random.seed(1337)\n",
    "\n",
    "# The dimension of z\n",
    "noise_dim = 100\n",
    "\n",
    "batch_size = 16\n",
    "steps_per_epoch = 312 # 50000 / 16\n",
    "epochs = 800\n",
    "\n",
    "save_path = 'dcgan-images'\n",
    "\n",
    "img_rows, img_cols, channels = 32, 32, 3\n",
    "\n",
    "optimizer = Adam(0.0002, 0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "_cell_guid": "98a90dbb-55a7-4bf1-acd2-096ec9adf718",
    "_uuid": "d1eca82d-ba97-4d46-b86d-485f2d86a62c",
    "colab": {},
    "colab_type": "code",
    "id": "k1bONiz8lRB8"
   },
   "outputs": [],
   "source": [
    "# Create path for saving images\n",
    "if save_path != None and not os.path.isdir(save_path):\n",
    "    os.mkdir(save_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "_cell_guid": "cccee622-d163-4927-9a5e-a2aabc80bc50",
    "_uuid": "c3da3020-b247-4ce2-af16-a4c2f193900f",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 51
    },
    "colab_type": "code",
    "id": "tc0gEPQFlRB-",
    "outputId": "68d2319a-09da-46a8-f8bb-d24a738673a1"
   },
   "outputs": [],
   "source": [
    "# Load and pre-process data\n",
    "(x_train, y_train), (x_test, y_test) = cifar10.load_data()\n",
    "# Normalize to between -1 and 1\n",
    "x_train = (x_train.astype(np.float32) - 127.5) / 127.5\n",
    "\n",
    "# Reshape and only save cat images\n",
    "x_train = x_train[np.where(y_train == 0)[0]].reshape((-1, img_rows, img_cols, channels))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "_cell_guid": "8cac0a57-f50d-4f2e-a276-5c742b5a74cc",
    "_uuid": "245e234f-0e99-4a17-8267-24b28f504306",
    "colab": {},
    "colab_type": "code",
    "id": "uEF5BN8flRCB"
   },
   "outputs": [],
   "source": [
    "def create_generator():\n",
    "    generator = Sequential()\n",
    "    \n",
    "    # Starting size\n",
    "    d = 4\n",
    "    generator.add(Dense(d*d*256, kernel_initializer=RandomNormal(0, 0.02), input_dim=noise_dim))\n",
    "    generator.add(LeakyReLU(0.2))\n",
    "    # 4x4x256\n",
    "    generator.add(Reshape((d, d, 256)))\n",
    "    \n",
    "    # 8x8x128\n",
    "    generator.add(Conv2DTranspose(128, (4, 4), strides=2, padding='same', kernel_initializer=RandomNormal(0, 0.02)))\n",
    "    generator.add(LeakyReLU(0.2))\n",
    "    \n",
    "    # 16x16*128\n",
    "    generator.add(Conv2DTranspose(128, (4, 4), strides=2, padding='same', kernel_initializer=RandomNormal(0, 0.02)))\n",
    "    generator.add(LeakyReLU(0.2))\n",
    "    \n",
    "    # 32x32x128\n",
    "    generator.add(Conv2DTranspose(128, (4, 4), strides=2, padding='same', kernel_initializer=RandomNormal(0, 0.02)))\n",
    "    generator.add(LeakyReLU(0.2))\n",
    "    \n",
    "    # 32x32x3\n",
    "    generator.add(Conv2D(channels, (3, 3), padding='same', activation='tanh', kernel_initializer=RandomNormal(0, 0.02)))\n",
    "    \n",
    "    generator.compile(loss='binary_crossentropy', optimizer=optimizer)\n",
    "    return generator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "_cell_guid": "5cfad0f4-6066-40c5-82a1-23c6fe0953cf",
    "_uuid": "9bdbbbb0-33dc-49d2-9dd2-127a96909ad6",
    "colab": {},
    "colab_type": "code",
    "id": "UplAGet9lRCD"
   },
   "outputs": [],
   "source": [
    "def create_discriminator():\n",
    "    discriminator = Sequential()\n",
    "    \n",
    "    discriminator.add(Conv2D(64, (3, 3), padding='same', kernel_initializer=RandomNormal(0, 0.02), input_shape=(img_cols, img_rows, channels)))\n",
    "    discriminator.add(LeakyReLU(0.2))\n",
    "    \n",
    "    discriminator.add(Conv2D(128, (3, 3), strides=2, padding='same', kernel_initializer=RandomNormal(0, 0.02)))\n",
    "    discriminator.add(LeakyReLU(0.2))\n",
    "    \n",
    "    discriminator.add(Conv2D(128, (3, 3), strides=2, padding='same', kernel_initializer=RandomNormal(0, 0.02)))\n",
    "    discriminator.add(LeakyReLU(0.2))\n",
    "    \n",
    "    discriminator.add(Conv2D(256, (3, 3), strides=2, padding='same', kernel_initializer=RandomNormal(0, 0.02)))\n",
    "    discriminator.add(LeakyReLU(0.2))\n",
    "    \n",
    "    discriminator.add(Flatten())\n",
    "    discriminator.add(Dropout(0.4))\n",
    "    discriminator.add(Dense(1, activation='sigmoid', input_shape=(img_cols, img_rows, channels)))\n",
    "    \n",
    "    discriminator.compile(loss='binary_crossentropy', optimizer=optimizer)\n",
    "    return discriminator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "_cell_guid": "ae02d478-086e-4ebb-a785-3a82b20932c1",
    "_uuid": "86e3c88d-614c-4c03-b332-3331d4fd8651",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 377
    },
    "colab_type": "code",
    "id": "KBjg3vMIlRCF",
    "outputId": "6b9ac932-3951-4956-e769-ce3a8f58381c",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "discriminator = create_discriminator()\n",
    "generator = create_generator()\n",
    "\n",
    "# Make the discriminator untrainable when we are training the generator.  This doesn't effect the discriminator by itself\n",
    "discriminator.trainable = False\n",
    "\n",
    "# Link the two models to create the GAN\n",
    "gan_input = Input(shape=(noise_dim,))\n",
    "fake_image = generator(gan_input)\n",
    "\n",
    "gan_output = discriminator(fake_image)\n",
    "\n",
    "gan = Model(gan_input, gan_output)\n",
    "gan.compile(loss='binary_crossentropy', optimizer=optimizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "_cell_guid": "b8aa339a-5e8a-4268-a10f-de497a766be9",
    "_uuid": "2029c1da-47e2-44a5-b273-11e13e51a216",
    "colab": {},
    "colab_type": "code",
    "id": "E4K79VfLlRCH"
   },
   "outputs": [],
   "source": [
    "# Display images, and save them if the epoch number is specified\n",
    "def show_images(noise, epoch=None):\n",
    "    generated_images = generator.predict(noise)\n",
    "    plt.figure(figsize=(10, 10))\n",
    "    \n",
    "    for i, image in enumerate(generated_images):\n",
    "        plt.subplot(10, 10, i+1)\n",
    "        if channels == 1:\n",
    "            plt.imshow(np.clip(image.reshape((img_rows, img_cols)), 0.0, 1.0), cmap='gray')\n",
    "        else:\n",
    "            plt.imshow(np.clip(image.reshape((img_rows, img_cols, channels)), 0.0, 1.0))\n",
    "        plt.axis('off')\n",
    "    \n",
    "    plt.tight_layout()\n",
    "    \n",
    "    if epoch != None and save_path != None:\n",
    "        plt.savefig(f'{save_path}/gan-images_epoch-{epoch}.png')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "_cell_guid": "d5e891f3-a40b-4796-a8fa-4a53acd8af54",
    "_uuid": "3a188e82-2d2d-4d31-87f5-9798bd5fbd1b",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "colab_type": "code",
    "id": "myT68aLJlRCJ",
    "outputId": "d0c7702a-8969-46e2-c1ad-09e05fdb0ea8",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Constant noise for viewing how the GAN progresses\n",
    "static_noise = np.random.normal(0, 1, size=(100, noise_dim))\n",
    "\n",
    "# Training loop\n",
    "for epoch in range(epochs):\n",
    "    for batch in range(steps_per_epoch):\n",
    "        noise = np.random.normal(0, 1, size=(batch_size, noise_dim))\n",
    "        real_x = x_train[np.random.randint(0, x_train.shape[0], size=batch_size)]\n",
    "\n",
    "        fake_x = generator.predict(noise)\n",
    "\n",
    "        x = np.concatenate((real_x, fake_x))\n",
    "\n",
    "        disc_y = np.zeros(2*batch_size)\n",
    "        disc_y[:batch_size] = 0.9\n",
    "\n",
    "        d_loss = discriminator.train_on_batch(x, disc_y)\n",
    "\n",
    "        y_gen = np.ones(batch_size)\n",
    "        g_loss = gan.train_on_batch(noise, y_gen)\n",
    "\n",
    "    print(f'Epoch: {epoch} \\t Discriminator Loss: {d_loss} \\t\\t Generator Loss: {g_loss}')\n",
    "    if epoch % 2 == 0:\n",
    "        show_images(static_noise, epoch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "_cell_guid": "fafa4cae-38eb-433f-bad5-87a1af0275aa",
    "_uuid": "29dbb287-e9d2-4b96-b41d-514fc94d9654",
    "colab": {},
    "colab_type": "code",
    "id": "6-yBf84llRCN"
   },
   "outputs": [],
   "source": [
    "# Turn the training process into a GIF\n",
    "from PIL import Image, ImageDraw\n",
    "\n",
    "image_names = os.listdir(save_path)\n",
    "\n",
    "frames = []\n",
    "for image in sorted(image_names, key=lambda name: int(''.join(i for i in name if i.isdigit()))):\n",
    "    frames.append(Image.open(save_path + '/' + image))\n",
    "\n",
    "frames[0].save('gan_training.gif', format='GIF', append_images=frames[1:], save_all=True, duration=80, loop=0)\n",
    "\n",
    "discriminator.save('dcdiscriminator.h5')\n",
    "generator.save('dcgenerator.h5')"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "name": "GANs Tutorial.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
